import torch

from torch import nn
from Utils import Misc
from knn_cuda import KNN
from einops import rearrange

from Utils.Registry import Registry
from Utils.Tool import index_points, SwapAxes
from Utils.pointnet_util import PointNetFeaturePropagation

# Registry of transition (down-/up-sampling) layers. The classes below register
# themselves under a string name via @transition.register_module('<name>').
transition = Registry('Transition')


class Group(nn.Module):
    """Split a point cloud into local neighbourhoods.

    Centers are chosen with farthest point sampling (``Misc.fps``). Each center
    is then given its ``group_size`` nearest neighbours via KNN.

    Args:
        group_size: number of neighbours (k) gathered around every center.
        num_group: number of centers sampled by FPS.
            NOTE(review): the default ``None`` cannot satisfy the
            ``idx.size(1) == self.num_group`` check below, so callers
            presumably always pass an int — confirm.
    """

    def __init__(self, group_size, num_group=None):
        super().__init__()
        self.group_size = group_size
        self.num_group = num_group
        # transpose_mode=True: knn_cuda takes point-major (B, N, C) tensors and
        # returns (B, n_query, k) results, matching the asserts in forward().
        self.knn = KNN(k=self.group_size, transpose_mode=True)

    def forward(self, xyz):
        '''
            input: xyz B N 3
            ---------------------------
            output: idx        : B G M  (indices of each center's M neighbours in xyz)
                    center_idx : indices of the sampled centers, cast to int64
                    center     : B G 3
        '''
        # Explicit rank check; the original relied on tuple-unpacking xyz.shape,
        # which also raised ValueError for non-3D input.
        if xyz.dim() != 3:
            raise ValueError(
                f'expected xyz of shape (B, N, 3), got {tuple(xyz.shape)}'
            )
        # FPS picks the G centers; fps is given a contiguous copy of the coordinates.
        center_idx, center = Misc.fps(xyz.contiguous(), self.num_group)  # B G 3
        # KNN of every center against the full cloud -> neighbour indices.
        _, idx = self.knn(xyz, center)  # B G M
        # Internal sanity checks on the KNN output layout.
        assert idx.size(1) == self.num_group
        assert idx.size(2) == self.group_size
        return idx, center_idx.type(torch.int64), center


class Encoder(nn.Module):
    """Mini-PointNet that pools each local group into one feature vector.

    There are two pointwise MLP stages.
    - The first stage embeds every point.
    - Its per-group max is concatenated back onto each point.
    - The second stage runs on that concatenation and is max-pooled again,
      giving one vector per group.

    Args:
        in_channels: size of the last input axis (D); used as the channel count
            of the first Conv1d.
        out_channels: size of the per-group output feature (C).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def pointwise_mlp(width_in):
            # Conv1d(k=1) -> BN -> ReLU -> Conv1d(k=1), applied per point.
            return nn.Sequential(
                nn.Conv1d(width_in, out_channels, 1),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv1d(out_channels, out_channels, 1)
            )

        self.first_conv = pointwise_mlp(in_channels)
        # Second stage sees [group max ; per-point feature], hence double width.
        self.second_conv = pointwise_mlp(out_channels * 2)

    def forward(self, inputs):
        '''
            features : B G N D
            -----------------
            feature_global : B G C
        '''
        batch, groups, points, dims = inputs.shape
        # Fold groups into the batch axis, then go channel-first for Conv1d:
        # (B*G, N, D) -> (B*G, D, N).
        flat = inputs.reshape(batch * groups, points, dims).transpose(1, 2)
        per_point = self.first_conv(flat)
        group_max = per_point.max(dim=2, keepdim=True).values
        broadcast = group_max.expand(-1, -1, points)
        fused = self.second_conv(torch.cat((broadcast, per_point), dim=1))
        pooled = fused.max(dim=2).values
        return pooled.reshape(batch, groups, -1)


@transition.register_module('FPS_Compatibility')
class FPS_Compatibility(nn.Module):
    """Stride-based downsampling: FPS centers, KNN grouping, then Encoder pooling.

    Unlike ``FarthestSampling``, the number of centers is not fixed at
    construction. Each call uses ``num_points // stride`` centers, so the
    output size follows the input size.

    Args:
        in_channels: channel count of the gathered neighbour features (the
            Encoder's input width).
        out_channels: width of the per-center output feature.
        group_size: number of neighbours (k) per center.
        stride: downsampling factor applied to the number of input points.
    """

    def __init__(self, in_channels, out_channels, group_size, stride):
        super(FPS_Compatibility, self).__init__()
        self.group_size = group_size
        self.stride = stride
        self.knn = KNN(k=self.group_size, transpose_mode=True)
        self.encoder = Encoder(in_channels, out_channels)

    def forward(self, xyz, features):
        """Downsample ``xyz`` by ``stride`` and encode each neighbourhood.

        Args:
            xyz: point coordinates, a rank-3 tensor with points on axis 1
                (presumably B N 3 — confirm).
            features: per-point features gathered with ``index_points``
                (presumably B N in_channels — confirm against callers).

        Returns:
            ``(center_idx, center, group_input)``:
            - ``center_idx`` is cast to int64.
            - ``center`` is B G 3.
            - ``group_input`` is the Encoder output, one vector per center.
        """
        # Tuple-unpacking also enforces a rank-3 xyz, as before.
        _, num_points, _ = xyz.shape
        num_group = num_points // self.stride
        # Farthest point sampling picks the centers.
        center_idx, center = Misc.fps(xyz.contiguous(), num_group)  # B G 3
        # Gather the group_size nearest neighbours of every center.
        _, idx = self.knn(xyz, center)  # B G M
        assert idx.size(1) == num_group
        assert idx.size(2) == self.group_size

        # Gather each neighbourhood's features and pool them to one vector per
        # center (only `features` are encoded; coordinates are not concatenated).
        group_input = self.encoder(index_points(features, idx))
        return center_idx.type(torch.int64), center, group_input


@transition.register_module('FarthestSampling')
class FarthestSampling(nn.Module):
    """Fixed-size downsampling: group the cloud with ``Group``, then encode each patch.

    Args:
        in_channels: channel count of the gathered neighbour features.
        out_channels: width of the per-center output feature.
        group_size: number of neighbours (k) per center.
        num_group: number of FPS centers.
    """

    def __init__(self, in_channels, out_channels, group_size, num_group):
        super().__init__()
        self.group_divider = Group(group_size=group_size, num_group=num_group)
        self.encoder = Encoder(in_channels, out_channels)

    def forward(self, xyz, features):
        """Return ``(center_idx, center, encoded)``.

        Both ``center_idx`` and ``center`` come straight from ``Group``.
        ``encoded`` pools each center's gathered neighbour features into one
        vector.
        """
        neighbour_idx, center_idx, center = self.group_divider(xyz)
        # Collect the features of every neighbourhood, then pool per center.
        grouped = index_points(features, neighbour_idx)
        encoded = self.encoder(grouped)
        return center_idx, center, encoded


@transition.register_module('Interpolation')
class Interpolation(nn.Module):
    """Upsampling transition: fuse decoder features with encoder skip features.

    Coarse decoder features are projected, then interpolated onto the denser
    point set. There they are added to the projected encoder (skip) features.

    Args:
        in_channels_dec: channel count of the decoder features ``points1``.
        in_channels_enc: channel count of the encoder features ``points2``.
        out_channels: width of both projections and of the fused output.
    """

    def __init__(self, in_channels_dec, in_channels_enc, out_channels):
        super().__init__()

        def projection(width_in):
            # Pointwise Conv1d + BN + ReLU between two SwapAxes, so it can be
            # applied to point-major tensors (assumes SwapAxes swaps axes 1 and 2
            # — TODO confirm). Keep this exact layout: it fixes the state_dict keys.
            return nn.Sequential(
                SwapAxes(),
                nn.Conv1d(width_in, out_channels, 1),
                nn.BatchNorm1d(out_channels),  # TODO
                nn.ReLU(inplace=True),
                SwapAxes(),
            )

        self.fc1 = projection(in_channels_dec)
        self.fc2 = projection(in_channels_enc)
        # Built with an empty MLP list, so presumably interpolation only, with
        # no learned layers — confirm in Utils.pointnet_util.
        self.fp = PointNetFeaturePropagation(-1, [])

    def forward(self, xyz1, points1, xyz2, points2):
        """Return ``(xyz2, fused_features)``.

        Args:
            xyz1: coordinates of the coarse set (fewer points than ``xyz2``).
            points1: decoder features attached to ``xyz1``.
            xyz2: coordinates of the dense target set.
            points2: encoder features attached to ``xyz2``.
        """
        coarse = self.fc1(points1)
        # fp is fed channel-first tensors, so transpose in and back out.
        dense_xyz = xyz2.transpose(1, 2)
        sparse_xyz = xyz1.transpose(1, 2)
        upsampled = self.fp(dense_xyz, sparse_xyz, None, coarse.transpose(1, 2))
        upsampled = upsampled.transpose(1, 2)
        skip = self.fc2(points2)
        return xyz2, upsampled + skip

